# Declares this Bazel module under the name "xla".
module(name = "xla")

##############################################################
# Bazel module dependencies
#
# Each bazel_dep names a module and the version requested for it. Where
# repo_name is given, the module is visible inside this repository under that
# name instead of its module name.

bazel_dep(name = "abseil-cpp", version = "20250814.0", repo_name = "com_google_absl")
bazel_dep(name = "abseil-py", version = "2.1.0", repo_name = "absl_py")
bazel_dep(name = "bazel_features", version = "1.36.0")
bazel_dep(name = "bazel_skylib", version = "1.8.1")
bazel_dep(name = "boringssl", version = "0.20250818.0")
bazel_dep(name = "curl", version = "8.11.0")
bazel_dep(name = "google_benchmark", version = "1.8.5", repo_name = "com_google_benchmark")

# The googletest module itself is exposed as "com_google_googletest_upstream";
# the "com_google_googletest" repo name is given to xla_googletest_wrapper.
bazel_dep(name = "googletest", version = "1.17.0", repo_name = "com_google_googletest_upstream")
bazel_dep(name = "xla_googletest_wrapper", version = "1.0", repo_name = "com_google_googletest")
bazel_dep(name = "grpc", version = "1.74.1", repo_name = "com_github_grpc_grpc")
bazel_dep(name = "gutil", version = "20250502.0", repo_name = "com_google_gutil")
bazel_dep(name = "jsoncpp", version = "1.9.6", repo_name = "jsoncpp_git")
bazel_dep(name = "or-tools", version = "9.12", repo_name = "com_google_ortools")
bazel_dep(name = "platforms", version = "1.0.0")
bazel_dep(name = "protobuf", version = "32.1", repo_name = "com_google_protobuf")
bazel_dep(name = "pybind11_abseil", version = "202402.0")
bazel_dep(name = "pybind11_bazel", version = "2.13.6")
bazel_dep(name = "pybind11_protobuf", version = "0.0.0-20250210-f02a2b7")
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
bazel_dep(name = "rules_cc", version = "0.2.0")
bazel_dep(name = "rules_license", version = "1.0.0")
bazel_dep(name = "rules_python", version = "1.6.0")
bazel_dep(name = "rules_shell", version = "0.6.1")
bazel_dep(name = "snappy", version = "1.2.1")
bazel_dep(name = "zlib", version = "1.3.1.bcr.5")

# Only for compatibility; these modules are not used directly. The DO_NOT_USE_
# repo names are placeholders so that nothing in this repository refers to them.
# TODO: change repo_name to None after upgrading Bazel to the latest 7.x.
bazel_dep(name = "eigen", version = "4.0.0-20241125.bcr.3", repo_name = "DO_NOT_USE_eigen")
bazel_dep(name = "grpc-java", version = "1.75.0", repo_name = "DO_NOT_USE_grpc_java")

# rules_ml_toolchain is declared without a version; its source is supplied by
# the archive_override below, which points at a GitHub archive of one commit.
# TODO: publish an official version of rules_ml_toolchain to BCR
bazel_dep(name = "rules_ml_toolchain")

# The same commit hash appears in strip_prefix and urls, so when moving to a
# different commit update strip_prefix, urls and integrity together.
#
# To calculate integrity:
# wget -O temp_module_archive.tar.gz <archive URL>
# HASH=$(openssl dgst -sha256 -binary temp_module_archive.tar.gz | openssl base64 -A)
# echo "sha256-${HASH}"
archive_override(
    module_name = "rules_ml_toolchain",
    integrity = "sha256-seXjBtixED5zubd438Op4GnSBmRDegMkaiNXJJYrXJQ=",
    strip_prefix = "rules_ml_toolchain-484235be45e6843db962c45d08fe4b2b65a6a24c",
    urls = ["https://github.com/google-ml-infra/rules_ml_toolchain/archive/484235be45e6843db962c45d08fe4b2b65a6a24c.tar.gz"],
)

# Applies the in-repo grpc.patch to the grpc module. patch_strip = 1 removes
# one leading path component from file names in the patch (like `patch -p1`).
# TODO: Upstream the patch?
single_version_override(
    module_name = "grpc",
    patch_strip = 1,
    patches = ["//third_party/grpc:grpc.patch"],
)

# Applies the in-repo patches under //third_party/absl to the abseil-cpp
# module, using patch_strip = 1 (like `patch -p1`).
# TODO: Upstream these patches?
single_version_override(
    module_name = "abseil-cpp",
    patch_strip = 1,
    patches = [
        "//third_party/absl:btree.patch",
        "//third_party/absl:build_dll.patch",
        "//third_party/absl:endian.patch",
        "//third_party/absl:rules_cc.patch",
        "//third_party/absl:check_op.patch",
        "//third_party/absl:check_op_2.patch",
    ],
)

# Use an unreleased version of googletest: the source comes from a GitHub
# archive of a single commit rather than from a registry release.
# NOTE(review): no integrity attribute is set here, so the downloaded archive
# is not checked against a pinned hash. Consider adding a "sha256-..." value
# computed from the archive at the URL below.
archive_override(
    module_name = "googletest",
    strip_prefix = "googletest-28e9d1f26771c6517c3b4be10254887673c94018",
    urls = ["https://github.com/google/googletest/archive/28e9d1f26771c6517c3b4be10254887673c94018.zip"],
)

# xla_googletest_wrapper is taken from a directory inside this repository
# (third_party/xla_googletest_wrapper) instead of being fetched.
local_path_override(
    module_name = "xla_googletest_wrapper",
    path = "third_party/xla_googletest_wrapper",
)

##############################################################
# C++ dependencies

# The repositories listed in use_repo are requested from the in-repo
# third_party_ext module extension (//third_party/extensions:third_party.bzl);
# use_repo makes each listed name visible to this module.
# TODO: most of them can be Bazel modules, but we need a release strategy for them
third_party = use_extension("//third_party/extensions:third_party.bzl", "third_party_ext")
use_repo(
    third_party,
    "FXdiv",
    "XNNPACK",
    "cpuinfo",
    "cudnn_frontend_archive",
    "dlpack",
    "ducc",
    "eigen_archive",
    "farmhash_archive",
    "farmhash_gpu_archive",
    "gloo",
    "highwayhash",
    "hwloc",
    "implib_so",
    "llvm-raw",
    "llvm_openmp",
    "ml_dtypes_py",
    "mpitrampoline",
    "nanobind",
    "nvshmem",
    "onednn",
    "pthreadpool",
    "rocm_device_libs",
    "shardy",
    "slinky",
    "stablehlo",
    "triton",
)

##############################################################
# Python toolchain and pypi dependencies

# Configures Python through the rules_python "python" extension: both the
# defaults tag and the toolchain tag are set to version 3.11, and the
# pythons_hub repository produced by the extension is brought into scope.
python = use_extension("@rules_python//python/extensions:python.bzl", "python")
python.defaults(python_version = "3.11")
python.toolchain(python_version = "3.11")
use_repo(python, "pythons_hub")

# PyPI dependencies via the rules_python "pip" extension.
pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")

# Adds extra BUILD content to the numpy wheel. The content defines three
# cc_library targets:
#   - numpy_headers_2: headers under site-packages/numpy/_core/include
#   - numpy_headers_1: headers under site-packages/numpy/core/include
#   - numpy_headers:   depends on both of the above
# NOTE(review): the two include roots presumably correspond to the NumPy 2.x
# and 1.x package layouts; confirm before removing either.
pip.whl_mods(
    additive_build_content = """
load("@rules_cc//cc:cc_library.bzl", "cc_library")
cc_library(
    name = "numpy_headers_2",
    hdrs = glob(["site-packages/numpy/_core/include/**/*.h"]),
    strip_include_prefix="site-packages/numpy/_core/include/",
)
cc_library(
    name = "numpy_headers_1",
    hdrs = glob(["site-packages/numpy/core/include/**/*.h"]),
    strip_include_prefix="site-packages/numpy/core/include/",
)
cc_library(
    name = "numpy_headers",
    deps = [":numpy_headers_2", ":numpy_headers_1"],
)
""",
    hub_name = "pypi_mods",
    whl_name = "numpy",
)

# Creates the "xla_pypi" hub from the Python 3.11 requirements lock file,
# applies the numpy wheel modification declared above, and adds a hub alias
# for the numpy_headers target.
pip.parse(
    extra_hub_aliases = {
        "numpy": ["numpy_headers"],
    },
    hub_name = "xla_pypi",
    python_version = "3.11",
    requirements_lock = "//:requirements_lock_3_11.txt",
    whl_modifications = {
        "@pypi_mods//:numpy.json": "numpy",
    },
)

# The xla_pypi hub is made visible to this module under the name "pypi".
use_repo(pip, "pypi_mods", pypi = "xla_pypi")

# In-repo extension that provides the python_version_repo repository.
python_version_ext = use_extension(
    "//third_party/extensions:python_version.bzl",
    "python_version_ext",
)
use_repo(python_version_ext, "python_version_repo")

##############################################################
# Other dependencies via module extensions

### TSL
# The "tsl" repository comes from the in-repo tsl_extension.
tsl_extension = use_extension("//third_party/extensions:tsl.bzl", "tsl_extension")
use_repo(tsl_extension, "tsl")

### LLVM
# The "llvm-project" repository comes from the in-repo llvm_extension.
llvm = use_extension("//third_party/extensions:llvm.bzl", "llvm_extension")
use_repo(llvm, "llvm-project")

### pybind
# The "pybind11" repository comes from an extension shipped in @pybind11_bazel.
pybind11_internal_configure = use_extension(
    "@pybind11_bazel//:internal_configure.bzl",
    "internal_configure_extension",
)
use_repo(pybind11_internal_configure, "pybind11")

### RBE
# Both extensions are defined in this repository under //third_party/extensions.
rbe_config = use_extension(
    "//third_party/extensions:rbe_config.bzl",
    "rbe_config_ext",
)
use_repo(rbe_config, "ml_build_config_platform")

remote_execution_configure = use_extension(
    "//third_party/extensions:remote_execution_configure.bzl",
    "remote_execution_configure_ext",
)
use_repo(remote_execution_configure, "local_config_remote_execution")

### From @rules_ml_toolchain
# Each extension below is loaded from @rules_ml_toolchain//third_party/extensions
# and contributes one repository to this module.
cuda_configure = use_extension(
    "@rules_ml_toolchain//third_party/extensions:cuda_configure.bzl",
    "cuda_configure_ext",
)
use_repo(cuda_configure, "local_config_cuda")

nccl_configure = use_extension(
    "@rules_ml_toolchain//third_party/extensions:nccl_configure.bzl",
    "nccl_configure_ext",
)
use_repo(nccl_configure, "local_config_nccl")

sycl_configure = use_extension(
    "@rules_ml_toolchain//third_party/extensions:sycl_configure.bzl",
    "sycl_configure_ext",
)
use_repo(sycl_configure, "local_config_sycl")

cuda_redist_init_ext = use_extension(
    "@rules_ml_toolchain//third_party/extensions:cuda_redist_init.bzl",
    "cuda_redist_init_ext",
)
use_repo(cuda_redist_init_ext, "cuda_cudart")

# Toolchain targets from @rules_ml_toolchain//cc. The labels are kept in their
# original order because registration order affects toolchain resolution.
register_toolchains(
    "@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64",
    "@rules_ml_toolchain//cc:linux_x86_64_linux_x86_64_cuda",
    "@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64",
    "@rules_ml_toolchain//cc:linux_aarch64_linux_aarch64_cuda",
)

### Other local config repos
# Extensions defined in this repository; each contributes one repository.
rocm_configure = use_extension(
    "//third_party/extensions:rocm_configure.bzl",
    "rocm_configure_ext",
)
use_repo(rocm_configure, "local_config_rocm")

tensorrt_configure = use_extension(
    "//third_party/extensions:tensorrt_configure.bzl",
    "tensorrt_configure_ext",
)
use_repo(tensorrt_configure, "local_config_tensorrt")

pjrt_nightly_timestamp = use_extension(
    "//build_tools/pjrt_wheels:nightly.bzl",
    "nightly_timestamp_repo_bzlmod",
)
use_repo(pjrt_nightly_timestamp, "nightly_timestamp")

pjrt_rc_number = use_extension(
    "//build_tools/pjrt_wheels:release_candidate.bzl",
    "rc_number_repo_bzlmod",
)
use_repo(pjrt_rc_number, "rc_number")
